/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "framework/graph/core/optimize/fusion/pattern_define.h"

#include <algorithm>

#include "infra/base/securestl.h"
#include "infra/base/assertion.h"

#include "framework/graph/core/cgraph/compute_graph.h"
#include "framework/graph/core/cgraph/graph_modifier.h"
#include "framework/graph/core/node/node.h"
#include "framework/graph/core/node/node_spec.h"
#include "framework/graph/core/node/node_walker.h"
#include "framework/graph/core/node/node_visitor.h"
#include "framework/graph/core/edge/edge.h"
#include "framework/graph/core/edge/edge_visitor.h"
#include "framework/graph/core/edge/endpoint.h"

using namespace std;

namespace hiai {
class PatternMappingImpl : public PatternMapping {
public:
    explicit PatternMappingImpl(const ge::ComputeGraph& graph) : out_(0), ownerGraph_(graph)
    {
    }

    ~PatternMappingImpl() override = default;

    ge::Node* Node(const Id& id) override
    {
        if (mapping_.count(id) > 0) {
            return const_cast<ge::Node*>(mapping_[id]);
        }

        return nullptr;
    }

    std::unique_ptr<ge::GraphSrcBoundary> ToGraphSrcBoundary() const override
    {
        std::vector<ge::Endpoint> outEndpionts;
        mapping_.at(out_)->ROLE(NodeWalker).ListOutDataEdges([&outEndpionts](ge::Edge& edge) {
            outEndpionts.push_back(edge.Src());
            return SUCCESS;
        });

        return make_unique_nothrow<ge::GraphSrcBoundary>(inEndpoints_, outEndpionts, ownerGraph_.ROLE(GraphModifier));
    }

public:
    void Insert(const Id& id, const ge::Node* node)
    {
        mapping_[id] = node;
    }

    void SetInEndpoint(const ge::Endpoint& ept)
    {
        if (std::find_if(inEndpoints_.begin(), inEndpoints_.end(), [&ept](ge::Endpoint& endpoint) {
                return (&(ept.Node()) == &(endpoint.Node())) && ept.Idx() == endpoint.Idx();
            }) == inEndpoints_.end()) {
            inEndpoints_.push_back(ept);
        }
    }

    void SetOutNode(const Id& id)
    {
        out_ = id;
    }

private:
    std::map<Id, const ge::Node*> mapping_;
    std::vector<ge::Endpoint> inEndpoints_;
    Id out_;
    const ge::ComputeGraph& ownerGraph_;
};

// Builds a pattern from its nodes, its declared input ids and its output ids, and indexes every node
// by its id in patternNodeMap_.
// NOTE(review): the default arguments below live on the definition, so they are only usable by code in
// this translation unit that follows this line; confirm whether the header is meant to carry them.
PatternDefine::PatternDefine(std::vector<PatternNode> nodes, PatternInputs inputs = {}, PatternOutput output = {})
    : patternNodes_(std::move(nodes)), patternInputs_(std::move(inputs)), patternOutput_(std::move(output))
{
    // patternNodeMap_ stores addresses of elements of patternNodes_. They remain valid only while
    // patternNodes_ is not reallocated (nothing in this file resizes it).
    // NOTE(review): a copied PatternDefine would still point into the source object's vector.
    for (auto& patternNode : patternNodes_) {
        patternNodeMap_[patternNode.id_] = &patternNode;
    }
}

// Returns the concatenation of the accepted op types of every output pattern node, in
// patternOutput_ order. Output ids that name no pattern node are skipped.
std::vector<std::string> PatternDefine::AttentionTypes() const
{
    std::vector<std::string> types;

    for (const Id& outId : patternOutput_) {
        // Pattern ids are keys (see patternNodeMap_), not positions in patternNodes_, so look the node
        // up by id_ instead of indexing the vector with the id (which could read the wrong node or
        // go out of bounds).
        auto outNode = std::find_if(patternNodes_.cbegin(), patternNodes_.cend(),
            [&outId](const PatternNode& node) { return node.id_ == outId; });
        if (outNode == patternNodes_.cend()) {
            continue;
        }
        types.insert(types.cend(), outNode->types_.cbegin(), outNode->types_.cend());
    }

    return types;
}

namespace {
// Decides whether pattern node `patternNode` (with pattern id `id`) may be bound to a graph node of op
// type `type`, given the ids already bound in the current MatchNode call (`ids`).
// - An id that is already bound is rejected unless the pattern node is repeatable.
// - A pattern node with neither inputs nor types accepts any op type.
// - Otherwise `type` must be one of the node's declared types.
bool IsMatched(PatternNode* patternNode, const string& type, std::vector<Id>& ids, Id id)
{
    // The caller looks the node up with patternNodeMap[id], which yields nullptr for an input id that
    // is not a declared pattern node; treat that as "no match" instead of dereferencing it.
    if (patternNode == nullptr) {
        return false;
    }

    const bool alreadyBound = std::find(ids.begin(), ids.end(), id) != ids.end();
    if (alreadyBound && !patternNode->repeatable_) {
        return false;
    }

    if (patternNode->inputs_.empty() && patternNode->types_.empty()) {
        return true;
    }

    return std::find(patternNode->types_.begin(), patternNode->types_.end(), type) != patternNode->types_.end();
}

// Records the destination endpoint of every incoming data edge of `node` as an input endpoint of the
// matched sub-graph.
void SetInputInEndpoints(ge::Node* node, std::unique_ptr<PatternMappingImpl>& patternMapping)
{
    PatternMappingImpl* mapping = patternMapping.get();
    auto recordDst = [mapping](ge::Edge& inEdge) {
        mapping->SetInEndpoint(inEdge.Dst());
        return SUCCESS;
    };
    node->ROLE(NodeWalker).ListInDataEdgesNonConst(recordDst);
}

// Binds pattern id `id` to the source node of `edge`, unless `id` is already bound.
// For a newly bound node that is a pattern boundary (no pattern inputs, or listed in `patternInputs`),
// it also records input endpoints:
// - an "any" node (no declared types) contributes only this edge's destination endpoint;
// - a typed node contributes the destination endpoints of all of its own incoming data edges.
void SetPatternMappingInfo(std::unique_ptr<PatternMappingImpl>& patternMapping,
    std::map<Id, PatternNode*>& patternNodeMap, PatternInputs& patternInputs, Id id, ge::Edge& edge)
{
    if (patternMapping->Node(id) != nullptr) {
        return; // first binding wins
    }

    ge::Node& srcNode = edge.SrcNode();
    patternMapping->Insert(id, &srcNode);

    const PatternNode* patternNode = patternNodeMap[id];
    const bool isDeclaredInput = std::find(patternInputs.begin(), patternInputs.end(), id) != patternInputs.end();
    if (!patternNode->inputs_.empty() && !isDeclaredInput) {
        return; // interior node of the pattern: no boundary endpoints to record
    }

    if (patternNode->types_.empty()) {
        patternMapping->SetInEndpoint(edge.Dst()); // SetInEndpoint ignores duplicates
        return;
    }

    SetInputInEndpoints(&srcNode, patternMapping);
}

// Tries to bind the pattern inputs of `currentId` to the incoming data edges of the graph node already
// bound to `currentId`. Each pattern input id bound here is appended to `matchedId` (the caller
// enqueues them). Returns FAILURE if `currentId` is unbound, or if the number of bound inputs differs
// from the number the pattern declares.
Status MatchNode(Id currentId, std::map<Id, PatternNode*>& patternNodeMap, PatternInputs& patternInputs,
    std::unique_ptr<PatternMappingImpl>& patternMapping, std::vector<Id>& matchedId)
{
    PatternNode* patternNode = patternNodeMap[currentId];
    ge::Node* node = patternMapping->Node(currentId);
    if (node == nullptr) {
        return FAILURE;
    }

    const auto& declaredInputs = patternNode->inputs_;
    if (declaredInputs.empty()) {
        return SUCCESS; // leaf of the pattern: nothing further to expand
    }

    // NOTE(review): when a checker is present its result is returned directly and the node's inputs
    // are never matched/enqueued; the checker is also never run for leaf nodes (returned above).
    // Confirm this is the intended contract for checker_.
    if (patternNode->checker_ != nullptr) {
        return (*(patternNode->checker_))(*node);
    }

    node->ROLE(NodeWalker).ListInDataEdges([&](ge::Edge& edge) {
        // The first declared input (in declaration order) that accepts this edge's source wins.
        auto chosen = std::find_if(declaredInputs.begin(), declaredInputs.end(), [&](const Id& candidate) {
            return IsMatched(patternNodeMap[candidate], edge.SrcNode().ROLE(NodeSpec).Type(), matchedId, candidate);
        });
        if (chosen != declaredInputs.end()) {
            const Id chosenId = *chosen;
            SetPatternMappingInfo(patternMapping, patternNodeMap, patternInputs, chosenId, edge);
            matchedId.push_back(chosenId);
        }
        return hiai::SUCCESS;
    });

    HIAI_EXPECT_EXEC(matchedId.size() == patternNode->inputs_.size() ? hiai::SUCCESS : hiai::FAILURE);

    return SUCCESS;
}

std::vector<Id> AllAnyNode(std::vector<PatternNode>& patternNodes, PatternInputs& patternInputs)
{
    std::vector<Id> anyNodeIds;
    std::for_each(patternNodes.begin(), patternNodes.end(), [&anyNodeIds, &patternInputs](PatternNode& patternNode) {
        if (patternNode.types_.size() == 0 &&
            std::find(patternInputs.begin(), patternInputs.end(), patternNode.id_) == patternInputs.end()) {
            anyNodeIds.push_back(patternNode.id_);
        }
    });

    return anyNodeIds;
}

std::map<Id, std::vector<Id>> AnyNodeToInputsMap(std::vector<Id>& anyNodeIds, std::vector<PatternNode>& patternNodes)
{
    std::map<Id, std::vector<Id>> any2Inputs;
    std::for_each(patternNodes.begin(), patternNodes.end(), [&anyNodeIds, &any2Inputs](PatternNode& patternNode) {
        for (const auto& input : patternNode.inputs_) {
            auto iter = std::find(anyNodeIds.cbegin(), anyNodeIds.cend(), input);
            if (iter != anyNodeIds.end()) {
                any2Inputs[*iter].push_back(patternNode.id_);
            }
        }
    });

    return any2Inputs;
}

// Returns true if `any` is one of the input nodes of `node`; false otherwise (including when `node` is
// null).
bool MatchAnyFromNodeInputs(const ge::Node* node, const ge::Node* any)
{
    bool found = false;
    HIAI_EXPECT_NOT_NULL_R(node, found);

    auto checkInput = [any, &found](ge::Node& inNode) {
        found = found || (&inNode == any);
        return SUCCESS;
    };
    node->ROLE(NodeWalker).ListInNodes(checkInput);

    return found;
}

bool VerifyHomology(std::map<Id, std::vector<Id>>& any2Inputs, PatternMapping* patternMapping)
{
    for (const auto& item : any2Inputs) {
        ge::Node* any = patternMapping->Node(item.first);
        for (const auto& id : item.second) {
            if (!MatchAnyFromNodeInputs(patternMapping->Node(id), any)) {
                return false;
            }
        }
    }

    return true;
}
} // namespace

// Validates the pattern's internal consistency before matching:
// - every pattern node is indexed exactly once by its id (no duplicate ids, no stray map entries);
// - every input id referenced by a pattern node is itself a declared pattern node;
// - no index entry is null.
bool PatternDefine::VerifyPattern() const
{
    // Duplicate ids shrink the map; entries default-inserted during a previous match grow it.
    if (patternNodes_.size() != patternNodeMap_.size()) {
        return false;
    }

    for (const auto& patternNode : patternNodes_) {
        if (patternNodeMap_.count(patternNode.id_) == 0) {
            return false;
        }
        // Matching looks inputs up with patternNodeMap_[id] and dereferences the result; an undeclared
        // input id would default-insert nullptr and be dereferenced, so reject it up front.
        for (const auto& inputId : patternNode.inputs_) {
            if (patternNodeMap_.count(inputId) == 0) {
                return false;
            }
        }
    }

    for (const auto& entry : patternNodeMap_) {
        if (entry.second == nullptr) {
            return false;
        }
    }

    return true;
}

// Verifies that every "any" pattern node that is not a declared input is consumed by at least one
// pattern node, and that all its consumers are fed by the same bound graph node.
// Trivially true when the pattern has no such nodes.
bool PatternDefine::VerifyInputsHomology(PatternMapping* patternMapping)
{
    std::vector<Id> anyNodes = AllAnyNode(patternNodes_, patternInputs_);
    if (anyNodes.empty()) {
        return true;
    }

    // An any node with no consumers would be missing from the map, making the sizes differ.
    std::map<Id, std::vector<Id>> any2Inputs = AnyNodeToInputsMap(anyNodes, patternNodes_);
    HIAI_EXPECT_TRUE_R(anyNodes.size() == any2Inputs.size(), false);

    return VerifyHomology(any2Inputs, patternMapping);
}

// Returns true if the pattern has exactly one output id, that id is a known pattern node, and `node`'s
// op type is one of that pattern node's declared types.
bool PatternDefine::VerifyOutNode(const ge::Node& node)
{
    if (patternOutput_.size() != 1) {
        return false;
    }

    auto outEntry = patternNodeMap_.find(patternOutput_[0]);
    if (outEntry == patternNodeMap_.end()) {
        return false;
    }

    const auto& acceptedTypes = outEntry->second->types_;
    return std::find(acceptedTypes.begin(), acceptedTypes.end(), node.ROLE(NodeSpec).Type()) != acceptedTypes.end();
}

// Binds the pattern to the graph by a breadth-first walk that starts at the pattern's first output id,
// bound to `outputNode`. Each id matched by MatchNode is appended to the FIFO work list; an id matched
// from several parents may be visited more than once.
// Returns nullptr if any node fails to match or the any-node homology check fails.
std::unique_ptr<PatternMapping> PatternDefine::BfsFromOutput(const ge::ComputeGraph& graph, const ge::Node& outputNode)
{
    const Id rootId = patternOutput_[0];

    std::unique_ptr<PatternMappingImpl> patternMapping = make_unique_nothrow<PatternMappingImpl>(graph);
    HIAI_EXPECT_NOT_NULL_R(patternMapping, nullptr);
    patternMapping->Insert(rootId, &outputNode);
    patternMapping->SetOutNode(rootId);

    // FIFO via a read cursor: entries are only appended, never erased, so the visit order matches a queue.
    std::vector<Id> workList = {rootId};
    for (std::vector<Id>::size_type head = 0; head < workList.size(); ++head) {
        const Id currentId = workList[head]; // copy: the append below may reallocate workList
        std::vector<Id> matchedId;
        HIAI_EXPECT_EXEC_R(MatchNode(currentId, patternNodeMap_, patternInputs_, patternMapping, matchedId), nullptr);

        workList.insert(workList.end(), matchedId.begin(), matchedId.end());
    }
    HIAI_EXPECT_TRUE_R(VerifyInputsHomology(patternMapping.get()), nullptr);

    // Explicit move: the local's type differs from the declared return type.
    return std::move(patternMapping);
}

// Entry point: tries to match this pattern with `outputNode` as the pattern's output node.
// Returns nullptr when the pattern itself is inconsistent, when `outputNode`'s type is not an accepted
// output type, or when the breadth-first binding fails.
std::unique_ptr<PatternMapping> PatternDefine::Match(const ge::Node& outputNode, const ge::ComputeGraph& graph)
{
    // Cheap structural checks first, then the graph walk.
    HIAI_EXPECT_TRUE_R(VerifyPattern(), nullptr); // verify if pattern is valid
    HIAI_EXPECT_TRUE_R(VerifyOutNode(outputNode), nullptr);

    std::unique_ptr<PatternMapping> mapping = BfsFromOutput(graph, outputNode);
    return mapping;
}
} // namespace hiai